"""Generate plots from json files."""

import json
import os
from typing import List, Tuple

import matplotlib.pyplot as plt

# Line colour for each curve label. The raw-string key is matplotlib mathtext
# (rendered as "Floco" with a superscript "+"); make_plot uses it as the label
# for Floco runs with pers_lamda > 0.
COLORS = {"FedAvg": "red", "Floco": "blue", r"Floco$^{+}$": "green"}

# Absolute path of the directory containing this script (NOT the process's
# current working directory), so paths built from it do not depend on where
# the script is launched from.
DIR = os.path.dirname(os.path.abspath(__file__))


def read_from_results(path: str) -> Tuple[str, List[float], List[float], str, str, int]:
    """Parse one run's JSON record into the fields needed for plotting.

    Args:
        path: JSON file holding a ``run_config`` mapping (with ``algorithm``,
            ``pers_lamda``, ``dataset`` and ``dataset-split`` keys) and a
            ``round_res`` list of per-round result mappings.

    Returns:
        ``(algorithm, federated_accuracies, centralized_accuracies, dataset,
        split, pers_lamda)``. Each accuracy list holds the stored value times
        100 for every round that reported that metric, in file order; the two
        lists may therefore differ in length.

    Raises:
        KeyError: if a required ``run_config`` / ``round_res`` key is missing.
    """
    with open(path, "r", encoding="UTF-8") as handle:
        record = json.load(handle)

    run_config = record["run_config"]
    algorithm = run_config["algorithm"]
    pers_lamda = run_config["pers_lamda"]
    per_round = record["round_res"]

    def scaled(metric: str) -> List[float]:
        # A round may report only some metrics, so filter per metric.
        return [entry[metric] * 100 for entry in per_round if metric in entry]

    federated = scaled("federated_evaluate_accuracy")
    centralized = scaled("centralized_accuracy")

    return (
        algorithm,
        federated,
        centralized,
        run_config["dataset"],
        run_config["dataset-split"],
        pers_lamda,
    )


def make_plot(
    dir_path: str,
    dataset: str,
    split_name: str,
    plt_title: str,
    save_dir: str = "_static/",
) -> None:
    """Plot accuracy curves for every run in ``dir_path`` matching a dataset/split.

    Every ``*.json`` file in ``dir_path`` is parsed with ``read_from_results``.
    Runs whose recorded dataset and split equal ``dataset`` and ``split_name``
    are drawn on a two-panel figure: centralized test accuracy (left) and
    federated test accuracy (right). A ``Floco`` run with ``pers_lamda > 0``
    is labelled ``Floco$^{+}$``. The best federated accuracy of each plotted
    run is printed. The figure is saved as
    ``<plt_title with spaces replaced by "_">.png`` inside ``save_dir``.

    Args:
        dir_path: Directory containing the result JSON files.
        dataset: Dataset name a run must have to be plotted.
        split_name: Dataset split a run must have to be plotted.
        plt_title: Title from which the output file name is derived.
        save_dir: Output directory, created if missing. Defaults to
            ``"_static/"``, resolved against the current working directory.
    """
    fig, ax = plt.subplots(1, 2, figsize=(8, 3))

    # Only regular .json files are results; anything else (subdirectories,
    # hidden files) would make json.load fail. Sorting makes the plotting
    # order, and hence the legend order, deterministic.
    with os.scandir(dir_path) as entries:
        result_paths = sorted(
            entry.path
            for entry in entries
            if entry.is_file() and entry.name.endswith(".json")
        )

    for result_path in result_paths:
        (
            algorithm,
            federated_accuracies,
            centralized_accuracies,
            read_dataset,
            read_split,
            pers_lamda,
        ) = read_from_results(result_path)
        if read_dataset != dataset or read_split != split_name:
            continue
        if algorithm == "Floco" and pers_lamda > 0:
            algorithm = r"Floco$^{+}$"

        # Algorithms without an entry in COLORS fall back to matplotlib's
        # default colour cycle (color=None) instead of raising KeyError.
        color = COLORS.get(algorithm)

        # max() of an empty list raises, e.g. when a run recorded no
        # federated evaluation.
        if federated_accuracies:
            print(f"Max accuracy ({algorithm}): {max(federated_accuracies):.2f}")

        # The first centralized value is skipped, as before; presumably it is
        # the evaluation taken before round 1 — NOTE(review): confirm.
        centralized = centralized_accuracies[1:]

        # Explicit 1-based round numbers per series, so the x-axis matches the
        # "Rounds" label and series of different lengths cannot mismatch.
        ax[0].plot(
            range(1, len(centralized) + 1),
            centralized,
            color=color,
            label=algorithm,
        )
        ax[1].plot(
            range(1, len(federated_accuracies) + 1),
            federated_accuracies,
            color=color,
            label=algorithm,
        )

    ax[1].legend()
    ax[0].set_xlabel("Rounds")
    ax[0].set_ylabel("Accuracy")
    ax[1].set_xlabel("Rounds")
    ax[1].set_ylabel("Accuracy")
    ax[0].set_title("Centralized Test Accuracy")
    ax[1].set_title("Federated Test Accuracy")
    fig.tight_layout()

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"{plt_title.replace(' ', '_')}.png")
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)


if __name__ == "__main__":
    # Produce one figure per CIFAR10 split from the JSON results that the
    # baseline writes to ../results/, relative to this script's directory.
    DATASET = "CIFAR10"
    res_dir = os.path.join(DIR, "../results/")
    for split in ("Dirichlet", "Fold"):
        make_plot(res_dir, DATASET, split, plt_title=f"{DATASET} {split}")
